//----------------------------------------------------------------------------
//
// Copyright (C) Sartorius Stedim Data Analytics AB 2017 -
//
// Use, modification and distribution are subject to the Boost Software
// License, Version 1.0. (See http://www.boost.org/LICENSE_1_0.txt)
//
//----------------------------------------------------------------------------

#include <string.h>
#include <stdlib.h>
#include "sqprunner.h"
#include "utf8util.h"

/*
 * Initialises a runner object with private copies of the supplied strings.
 *
 * pObj            Object to initialise; must not be NULL.
 * szUSPName       Project (.usp) file name; required, must not be NULL.
 * szPredsetNames  Prediction set file name; may be NULL.
 * szPluginPath    Plugin path; may be NULL.
 *
 * The copies are owned by pObj and are released by SQPRunner_Destroy.
 * The project handle is reset to NULL.
 *
 * Returns 1 on success. Returns 0 if a required argument is NULL or a string
 * copy could not be allocated; in that case no memory is left allocated and
 * all string members are NULL, so SQPRunner_Destroy is still safe to call.
 */
int SQPRunner_Init(SQPRunner* pObj, char* szUSPName, char* szPredsetNames, char* szPluginPath)
{
   if (pObj == NULL)
   {
      return 0;
   }

   /* Start from a fully defined state so every exit path is consistent. */
   pObj->mszUSPName = NULL;
   pObj->mszPredsetNames = NULL;
   pObj->mszPluginPath = NULL;
   pObj->mProjectHandle = NULL;

   if (szUSPName == NULL)
   {
      return 0;
   }

   pObj->mszUSPName = (char*)strdup(szUSPName);
   if (pObj->mszUSPName == NULL)
   {
      goto alloc_failed;
   }

   if (szPredsetNames)
   {
      pObj->mszPredsetNames = (char*)strdup(szPredsetNames);
      if (pObj->mszPredsetNames == NULL)
      {
         goto alloc_failed;
      }
   }

   if (szPluginPath)
   {
      pObj->mszPluginPath = (char*)strdup(szPluginPath);
      if (pObj->mszPluginPath == NULL)
      {
         goto alloc_failed;
      }
   }
   return 1;

alloc_failed:
   /* Roll back whatever was copied before the failure; free(NULL) is a no-op. */
   free(pObj->mszUSPName);
   free(pObj->mszPredsetNames);
   free(pObj->mszPluginPath);
   pObj->mszUSPName = NULL;
   pObj->mszPredsetNames = NULL;
   pObj->mszPluginPath = NULL;
   return 0;
}

/*
 * Releases the string copies held by a runner object.
 *
 * pObj  Object whose strings are released; a NULL pointer is ignored.
 *
 * Each member is set to NULL after it is freed, so calling this function
 * twice on the same object does not free the same memory twice. The project
 * handle is not touched here.
 */
void SQPRunner_Destroy(SQPRunner* pObj)
{
   if (pObj == NULL)
   {
      return;
   }

   /* free(NULL) is a no-op, so the optional members need no guard. */
   free(pObj->mszUSPName);
   pObj->mszUSPName = NULL;

   free(pObj->mszPredsetNames);
   pObj->mszPredsetNames = NULL;

   free(pObj->mszPluginPath);
   pObj->mszPluginPath = NULL;
}

/*
 * Opens the project, predicts from the prediction set file and prints XVarPS.
 *
 * pObj  Runner initialised with SQPRunner_Init.
 * pOut  Stream that receives the "XVarPS:" header and the printed data.
 * pErr  Stream that receives a one-line message if any step fails.
 *
 * Steps: check the license, open the project named by mszUSPName, fetch the
 * model at index 1, open a file reader on mszPredsetNames (using
 * mszPluginPath), predict from the file and print the XVarPS vector data.
 *
 * Every handle acquired along the way is released on both the success and
 * the failure path, and the project is closed before returning.
 */
void SQPRunner_Run(SQPRunner* pObj, FILE* pOut, FILE* pErr)
{
   const char* szErrString = NULL;     /* Set when a step fails; NULL means success. */
   SQ_StringVector oFileNames = NULL;  /* Names of the files to import. */
   SQ_FileReader oFileHandle = NULL;
   SQ_FileReader_Specification oSpecification;
   SQ_Prediction pPrediction = NULL;
   SQ_VectorData pVectorData = NULL;
   SQ_Model pModel = NULL;
   SQ_Bool bValid;
   int iModelIndex = 1;
   int iModelNumber;

   if (SQ_IsLicenseFileValid(&bValid) != SQ_E_OK)
   {
      szErrString = "Could not read license.";
      goto cleanup;
   }
   else if (!bValid)
   {
      szErrString = "Invalid license.";
      goto cleanup;
   }

   /* SQPRunner_Init accepts a NULL prediction set name; reject it here
      instead of handing a NULL string to the library. */
   if (pObj->mszPredsetNames == NULL)
   {
      szErrString = "No prediction set file specified.";
      goto cleanup;
   }

   if (SQ_FAILED(SQ_InitStringVector(&oFileNames, 1)))
   {
      szErrString = "SQ_InitStringVector failed.";
      goto cleanup;
   }
   if (SQ_FAILED(SQ_SetStringInVector(oFileNames, 1, pObj->mszPredsetNames)))
   {
      szErrString = "SQ_SetStringInVector failed.";
      goto cleanup;
   }

   /* Load the .usp file. */
   if (SQ_FAILED(SQ_OpenProject(pObj->mszUSPName, NULL /* Not a password protected project*/, &pObj->mProjectHandle)))
   {
      szErrString = "Could not open project. OpenProject failed.";
      goto cleanup;
   }
   if (SQ_FAILED(SQ_GetModelNumberFromIndex(pObj->mProjectHandle, iModelIndex, &iModelNumber)))
   {
      szErrString = "SQ_GetModelNumberFromIndex failed.";
      goto cleanup;
   }
   if (SQ_FAILED(SQ_GetModel(pObj->mProjectHandle, iModelNumber, &pModel)))
   {
      szErrString = "SQ_GetModel failed.";
      goto cleanup;
   }

   /* Load the prediction set file. */
   if (SQ_FAILED(SQ_OpenFileReader(oFileNames, pObj->mszPluginPath, NULL, &oFileHandle)))
   {
      szErrString = "OpenFileReader failed. Possible couses:\n\t1: FileReader is not allowed by license file\n\t2: Corrupted Prediction set file.";
      goto cleanup;
   }
   /* Specify the format of the file, the data starts at the first column,
      the primary variable names are found at the first row and
      the data starts at the second row. */
   oSpecification.miFirstDataColumn = 1;
   oSpecification.miFirstDataRow = 2;
   oSpecification.miPrimaryVariableIDRow = 1;
   oSpecification.mpvecExcludedRows = NULL;

   /* Perform the prediction. */
   if (SQ_FAILED(SQ_FileReader_PredictFromFile(pModel, oFileHandle, &oSpecification, &pPrediction)))
   {
      szErrString = "PredictFromFile failed.";
      goto cleanup;
   }

   /* The file is no longer needed once the prediction exists. The handles
      are reset so the shared cleanup below does not release them again. */
   SQ_ClearStringVector(&oFileNames);
   oFileNames = NULL;
   SQ_CloseFileReader(&oFileHandle);
   oFileHandle = NULL;

   /* Get XVarPS */
   if (SQ_FAILED(SQ_GetXVarPS(pPrediction, SQ_Unscaled_Default, SQ_Backtransformed_Default, NULL, &pVectorData)))
   {
      szErrString = "GetXVarPS failed.";
      goto cleanup;
   }

   UTF8_printf(pOut, "XVarPS:\n");
   SQPRunner_PrintVectorData(pVectorData, pOut);

cleanup:
   /* Single exit path: report the failure (if any), then release whatever
      was acquired before reaching this point. */
   if (szErrString != NULL)
   {
      UTF8_printf(pErr, "%s\n", szErrString);
   }
   if (pVectorData != NULL)
   {
      SQ_ClearVectorData(&pVectorData);
   }
   if (pPrediction != NULL)
   {
      SQ_ClearPrediction(&pPrediction);
   }
   if (oFileHandle != NULL)
   {
      SQ_CloseFileReader(&oFileHandle);
   }
   if (oFileNames != NULL)
   {
      SQ_ClearStringVector(&oFileNames);
   }
   SQ_CloseProject(&pObj->mProjectHandle);
}

/*
 * Prints a vector data object as a tab separated table.
 *
 * pVectorData  Data to print.
 * pOut         Stream that receives the table.
 *
 * The column names are printed on the first line. Each following line holds
 * a row name followed by that row's values with eight decimals, every field
 * terminated by a tab. A blank line ends the table.
 *
 * If any library call fails, a message naming the failing call is written to
 * pOut and printing stops at that point.
 */
void SQPRunner_PrintVectorData(SQ_VectorData pVectorData, FILE* pOut)
{
   SQ_FloatMatrix pMatrix = 0;
   SQ_StringVector pRowNames = 0;
   SQ_StringVector pColumnNames = 0;
   float fVal;
   int iColIter;
   int iRowIter;
   int iRows = 0;
   int iCols = 0;
   char strVal[1000];

   if (SQ_GetDataMatrix(pVectorData, &pMatrix) != SQ_E_OK)
   {
      UTF8_printf(pOut, "SQ_GetDataMatrix failed.");
      return;
   }

   if (SQ_GetNumColumnsInFloatMatrix(pMatrix, &iCols) != SQ_E_OK)
   {
      UTF8_printf(pOut, "SQ_GetNumColumnsInFloatMatrix failed.");
      return;
   }
   if (SQ_GetNumRowsInFloatMatrix(pMatrix, &iRows) != SQ_E_OK)
   {
      /* All output goes through UTF8_printf so the stream is written
         consistently, the same way as every other message in this file. */
      UTF8_printf(pOut, "SQ_GetNumRowsInFloatMatrix failed.");
      return;
   }
   if (SQ_GetRowNames(pVectorData, &pRowNames) != SQ_E_OK)
   {
      UTF8_printf(pOut, "SQ_GetRowNames failed.");
      return;
   }
   if (SQ_GetColumnNames(pVectorData, &pColumnNames) != SQ_E_OK)
   {
      UTF8_printf(pOut, "SQ_GetColumnNames failed.");
      return;
   }

   SQPRunner_PrintStringVector(pColumnNames, '\t', pOut);

   /* Rows and columns are addressed with 1-based indices. */
   for (iRowIter = 1; iRowIter <= iRows; ++iRowIter)
   {
      /* Print the row name; the buffer size is taken from the buffer itself. */
      if (SQ_GetStringFromVector(pRowNames, iRowIter, strVal, sizeof(strVal)) != SQ_E_OK)
      {
         UTF8_printf(pOut, "SQ_GetStringFromVector failed.");
         return;
      }
      UTF8_printf(pOut, "%s\t", strVal);

      for (iColIter = 1; iColIter <= iCols; ++iColIter)
      {
         /* Print the results tab separated to the file */
         if (SQ_GetDataFromFloatMatrix(pMatrix, iRowIter, iColIter, &fVal) != SQ_E_OK)
         {
            UTF8_printf(pOut, "SQ_GetDataFromFloatMatrix failed.");
            return;
         }

         UTF8_printf(pOut, "%.8f\t", fVal);
      }
      UTF8_printf(pOut, "\n");
   }
   UTF8_printf(pOut, "\n");
}

/*
 * Prints every string of a string vector on one line.
 *
 * pStringVector  Vector whose strings are printed, in index order.
 * szSeparator    Character written after each string (including the last).
 * pOut           Stream that receives the output.
 *
 * The line is terminated with a newline. If a library call fails, a message
 * naming the failing call is written to pOut and the function returns
 * without writing the terminating newline.
 */
void SQPRunner_PrintStringVector(SQ_StringVector pStringVector, const char szSeparator, FILE* pOut)
{
   int iCount = 0;
   int iIndex = 0;
   char szBuffer[1000];

   if (SQ_E_OK != SQ_GetNumStringsInVector(pStringVector, &iCount))
   {
      UTF8_printf(pOut, "SQ_GetNumStringsInVector failed.");
      return;
   }

   /* The vector is 1-based: advance first, then fetch. */
   while (iIndex < iCount)
   {
      ++iIndex;

      if (SQ_E_OK != SQ_GetStringFromVector(pStringVector, iIndex, szBuffer, sizeof(szBuffer)))
      {
         UTF8_printf(pOut, "SQ_GetStringFromVector failed.");
         return;
      }
      UTF8_printf(pOut, "%s%c", szBuffer, szSeparator);
   }

   UTF8_printf(pOut, "\n");
}

